# Functions for fitting mixture IRT models, extracted from mixture_irt.Rmd

#BiocManager::install("aroma.light")
library(gtools)
library(Rcpp)
library(aroma.light)


## Rcpp-compiled helper: Monte Carlo marginal likelihood for the IRT component

# Compile the C++ helper `irt_marg_lik()` with Rcpp::cppFunction().
#
# irt_marg_lik(y, x, a, b, ptrim) returns, for every row k of the response
# matrix y, a Monte Carlo estimate of the log marginal likelihood
#   log( (1/S) * sum_i prod_j P(y[k, j] | x[i, ]) ),
# with P(y = 1 | x) = plogis(a[j] + sum_d x[i, d] * b[d, j]) and the S rows
# of x treated as draws of the latent traits. Responses equal to 1 use
# log(phat); any other non-missing value uses log(1 - phat); NA/NaN cells
# are skipped.
# Shapes implied by the indexing: y is n x J, x is S x D, a has length J,
# b is D x J.
# phat is clamped to [ptrim, 1 - ptrim] so neither log term can be -Inf.
# The average over draws is accumulated as a running log-sum-exp, so long
# response vectors do not underflow to log(0).
cppFunction("
        NumericVector irt_marg_lik(NumericMatrix y, NumericMatrix x,
                                   NumericVector a, NumericMatrix b,
                                   double ptrim) {
          const int n_resp = y.nrow();
          const int n_item = y.ncol();
          const int n_draw = x.nrow();
          const int n_dim = x.ncol();
          NumericVector mlik(n_resp);
          NumericVector row_log_lik(n_resp);
          // Running log-sum-exp state per respondent: the largest draw
          // log-likelihood seen so far, and sum_i exp(ll_i - run_max).
          NumericVector run_max(n_resp);
          NumericVector run_sum(n_resp);
          for (int k = 0; k < n_resp; k++) {
            run_max[k] = R_NegInf;
            run_sum[k] = 0.0;
          }
          for (int i = 0; i < n_draw; i++) {
            for (int k = 0; k < n_resp; k++) row_log_lik[k] = 0.0;
            for (int j = 0; j < n_item; j++) {
              double xb = 0.0;
              for (int d = 0; d < n_dim; d++) xb += x(i, d) * b(d, j);
              double phat = 1.0 / (1.0 + exp(-a[j] - xb));
              // Clamp both tails: phat can round to exactly 1.0 as well as 0.
              if (phat < ptrim) phat = ptrim;
              if (phat > 1.0 - ptrim) phat = 1.0 - ptrim;
              const double log_yes = log(phat);
              const double log_no = log(1.0 - phat);
              for (int k = 0; k < n_resp; k++) {
                if (!NumericMatrix::is_na(y(k, j)))
                  row_log_lik[k] += y(k, j) == 1 ? log_yes : log_no;
              }
            }
            for (int k = 0; k < n_resp; k++) {
              const double ll = row_log_lik[k];
              if (ll == R_NegInf) continue;  // adds exp(-Inf) = 0
              if (ll > run_max[k]) {
                run_sum[k] = run_sum[k] * exp(run_max[k] - ll) + 1.0;
                run_max[k] = ll;
              } else {
                run_sum[k] += exp(ll - run_max[k]);
              }
            }
          }
          for (int k = 0; k < n_resp; k++) {
            mlik[k] = run_max[k] + log(run_sum[k]) - log((double) n_draw);
          }
          return mlik;
        }
")


## Simulated data generators

# Simulate n respondents answering k binary items independently: respondent i
# answers 1 on item j when a Uniform(0, 1) draw falls below p[j].
# Returns list(y = n x k numeric 0/1 matrix, p = the item probabilities used).
ind_vote_data <- function(n=100, k=10, p = rbeta(k, 5, 5)) {
  z <- matrix(runif(n*k), n, k)
  # Compare column j of z with p[j]. Unlike t(apply(...)), this keeps the
  # n x k shape when k == 1. The default p is forced here, after runif(),
  # so random draws happen in the same order as before.
  y <- (z < rep(p, each = n)) + 0
  list(y=y, p=p)
}

# Simulate n respondents answering k binary items from a logistic IRT model
# with ndim standard-normal latent dimensions.
# Item slopes and intercepts are drawn from Uniform(-3, 3).
# Returns list(y = n x k 0/1 responses, phat = n x k response probabilities,
#   x = latent traits, b = ndim x k slopes, a = length-k intercepts).
irt_dat <- function(n=2000, k=10, ndim=3) {
  latent <- matrix(rnorm(n * ndim), nrow = n, ncol = ndim)
  slopes <- matrix(runif(ndim * k, -3, 3), nrow = ndim, ncol = k)
  intercepts <- runif(k, -3, 3)
  design <- cbind(1, latent)
  coefs <- rbind(intercepts, slopes)
  prob <- plogis(design %*% coefs)
  draws <- matrix(runif(n * k), nrow = n, ncol = k)
  responses <- ifelse(draws < prob, 1, 0)
  list(y = responses, phat = prob, x = latent, b = slopes, a = intercepts)
}

# Simulate a mixed population by stacking three blocks of rows: coin-flip
# respondents (every item p = 1/2), independent voters, and IRT respondents.
# With equal_marginals = TRUE, the independent voters use the IRT block's
# observed item means as their probabilities; otherwise they use
# Beta(3, 3) draws.
# Returns list(w = true group label per row, dat = stacked 0/1 matrix,
#   ind_par = independent-voter probabilities, irt_par = list(b, a, x)).
mix_data <- function(items = 50, ndim=1, sim_ind_n = 500, sim_flip_n = 200, sim_irt_n = 1300,
                     equal_marginals=FALSE) {
  irt_part <- irt_dat(n = sim_irt_n, ndim = ndim, k = items)
  if (equal_marginals) {
    ind_p <- apply(irt_part$y, 2, mean, na.rm = TRUE)
  } else {
    ind_p <- rbeta(items, 3, 3)
  }
  ind_part <- ind_vote_data(n = sim_ind_n, k = items, p = ind_p)
  flip_part <- ind_vote_data(n = sim_flip_n, k = items, p = rep(1/2, items))
  group_names <- c("flip", "ind", "irt")
  group_sizes <- c(sim_flip_n, sim_ind_n, sim_irt_n)
  list(
    w = rep(group_names, times = group_sizes),
    dat = rbind(flip_part$y, ind_part$y, irt_part$y),
    ind_par = ind_part$p,
    irt_par = list(b = irt_part$b, a = irt_part$a, x = irt_part$x)
  )
}


## Likelihoods...

# Vectorised a * log(b) with the convention 0 * log(0) = 0 (as in xlogy).
# Where b <= 0, the result is 0 only when a == 0. A positive weight on a
# zero-probability outcome gives -Inf rather than silently scoring it as log(1).
vlog <- function(a, b) ifelse(b > 0, a * log(b), ifelse(a == 0, 0.0, -Inf))

# Log-likelihood of each row of y under the coin-flip model, where every item
# has response probability 0.5.
# NA responses are dropped from each row's sum.
# Returns list(lk = per-row log-likelihood).
flip_vote_loglik <- function(y) {
  coin <- rep(0.5, NCOL(y))
  score_row <- function(votes) {
    yes_part <- vlog(votes, coin)
    no_part <- vlog(1 - votes, 1 - coin)
    sum(yes_part + no_part, na.rm = TRUE)
  }
  list(lk = apply(y, 1, score_row))
}

# Fit the independent-voting model: each item's response probability is the
# w-weighted mean of its column (NA responses are ignored).
# Returns list(ivp = per-item probabilities, lk = per-row log-likelihood under
# those probabilities, with NA responses dropped from each row's sum).
ind_vote_est <- function(y, w) {
  item_p <- apply(y, 2, function(col) weighted.mean(col, w = w, na.rm = TRUE))
  score_row <- function(votes) {
    sum(vlog(votes, item_p) + vlog(1 - votes, 1 - item_p), na.rm = TRUE)
  }
  list(ivp = item_p, lk = apply(y, 1, score_row))
}


# Fit a (possibly weighted) logistic IRT model by repeated weighted PCA
# (aroma.light::wpca) of a working response z.
#
# y:      binary response matrix (rows = respondents); NA = missing.
# w:      optional per-row weights passed to wpca(); NULL = unweighted.
# iters:  number of iterations (must be >= 1).
# ndim:   number of latent dimensions kept.
# starts: optional previous fit (list with x, a, b) used to warm-start z.
#
# Returns list(x = scores, a = intercepts (wpca's xMean), b = loadings,
#   lk = per-row unweighted log-likelihood over observed cells,
#   lt = total (row-weighted) log-likelihood over observed cells,
#   svd = the first ndim singular values from wpca).
# With ndim == 1, `[, 1:ndim]` drops, so x and b come back as vectors.
# If every weight is 0, a list of NA placeholders is returned.
qm_irt <- function(y, w=NULL, iters=10, ndim=3, starts=NULL) {
  # `&&` short-circuits, so max() is never called on a NULL w (which would
  # warn and return -Inf).
  if (!is.null(w) && max(w) == 0) {
    return(list(x=NA, a=NA, b=NA, lk=NA, lt=NA, svd=NA))
  }
  if (iters < 1) {
    stop("`iters` must be at least 1", call. = FALSE)
  }
  if (is.null(starts)) {
    # Cold start: responses 0/1 map to working values -2/+2; missing cells 0.
    z <- ifelse(is.na(y), 0, 4*(y-0.5))
    x <- rep(Inf, NROW(z))  # makes the first reported change Inf
  } else {
    x <- starts$x
    aa <- cbind(1, starts$x) %*% rbind(starts$a, starts$b)
    pr <- plogis(aa)
    z <- ifelse(is.na(y), aa, aa + 4*(y-pr))
  }
  for (i in seq_len(iters)) {
    x_prev <- x
    res <- wpca(z, w=w, center=TRUE)
    a <- res$xMean
    x <- res$pc[,1:ndim]
    b <- res$vt[1:ndim,]
    aa <- cbind(1, x) %*% rbind(a, b)
    pr <- plogis(aa)
    # Per-cell Bernoulli log-likelihood, plus its row-weighted version.
    cell_lk <- vlog(y, pr) + vlog(1-y, 1-pr)
    lk <- if (is.null(w)) cell_lk else w * cell_lk
    lt <- sum(lk[which(!is.na(y))])
    # Working response for the next PCA: the linear predictor plus a scaled
    # residual. 4 = 1 / max p(1 - p). NOTE(review): this looks like the usual
    # quadratic bound on the logistic loss; confirm. Missing cells are
    # imputed with the current fit.
    z <- ifelse(is.na(y), aa, aa + 4*(y-pr))
    rmsd <- sqrt(sum((x-x_prev)^2))/sd(x)
    if (i == 1 || i %% 10 == 0 || i == iters) {
      cat(sprintf("Iteration: %i (%9.4f, %6.4f)\n", i, lt, rmsd))
    }
  }
  list(x=x, a=a, b=b,
       # Unweighted row sums (rather than sum(w * l) / w), which stay defined
       # for w = NULL and for zero-weight rows.
       lk=rowSums(cell_lk, na.rm=TRUE),
       lt=lt,
       svd=res$d[1:ndim])
}

# Fit a three-component mixture by EM:
#   1. an IRT component (qm_irt),
#   2. an independent-voting component (ind_vote_est),
#   3. a coin-flip component (every response has probability 0.5).
#
# y:               binary response matrix (rows = respondents), NAs allowed.
# iters:           number of outer EM iterations (must be >= 1).
# ndim:            number of latent IRT dimensions.
# irt_iters:       qm_irt iterations per outer iteration.
# xsamp:           number of latent-trait draws for the IRT marginal likelihood.
# w_starts:        optional n x 3 starting responsibilities; otherwise drawn
#                  from Dirichlet(w_alpha).
# equal_marginals: if TRUE, the independent-voting probabilities are
#                  estimated with weights 1 - w[,3] instead of w[,2].
# Returns list(w = n x 3 responsibilities (columns irt, ind, flip),
#   ivp, fvp, irt = last qm_irt fit, irt_mlk = last IRT log marginal
#   likelihood per row).
em_mix_irt <- function(y, iters=40, ndim=1, irt_iters=50, xsamp=1000,
                       w_starts=NULL, w_alpha=c(5,1,1), equal_marginals=FALSE) {
  if (iters < 1) {
    stop("`iters` must be at least 1", call. = FALSE)
  }
  r <- NULL
  w <- if (is.null(w_starts)) rdirichlet(NROW(y), alpha=w_alpha) else w_starts
  p <- colMeans(w)
  # The flip component has no free parameters, so its likelihood is the same
  # every iteration (no weights needed).
  fvp <- flip_vote_loglik(y)
  for (i in seq_len(iters)) {
    cat(sprintf("Outer iteration: %i\n", i))
    # M-step: refit the components using the current responsibilities.
    r <- qm_irt(y, w=w[,1], ndim=ndim, iters=irt_iters, starts=r)
    if (ndim == 1) {
      # qm_irt returns vectors for one dimension; restore the matrix shapes
      # used by irt_marg_lik and by the next warm start.
      r$x <- as.matrix(r$x)
      r$b <- t(r$b)
    }
    ww <- if (equal_marginals) (1-w[,3]) else w[,2]
    ivp <- ind_vote_est(y, w=ww)
    if (max(w[,1]) > 0) {
      # Draw latent traits from the fitted scores, weighted by IRT membership.
      x_idx <- sample.int(NROW(r$x), size=xsamp, replace=TRUE, prob=w[,1])
      x_sample <- r$x[x_idx, , drop=FALSE]
      irt_mlk <- irt_marg_lik(y=y, x=x_sample, a=r$a, b=r$b, ptrim=0.001)
    } else {
      # Placeholder: p[1] is 0 here, so log(p[1]) + irt_mlk is -Inf anyway.
      irt_mlk <- 1
    }
    # E-step on the log scale. Each responsibility is proportional to
    # p_c * lik_c. Subtracting the row maximum before exp() avoids the
    # overflow/underflow (Inf/Inf = NaN) of exponentiating raw
    # log-likelihood differences.
    l_irt <- log(p[1]) + irt_mlk
    l_ind <- log(p[2]) + ivp$lk
    l_flip <- log(p[3]) + fvp$lk
    top <- pmax(l_irt, l_ind, l_flip)
    e_irt <- exp(l_irt - top)
    e_ind <- exp(l_ind - top)
    e_flip <- exp(l_flip - top)
    total <- e_irt + e_ind + e_flip
    w[,1] <- e_irt / total
    w[,2] <- e_ind / total
    w[,3] <- e_flip / total
    p <- colMeans(w)
    cat(sprintf("Group probs = (%6.4f, %6.4f, %6.4f)\n\n",  p[1], p[2], p[3]))
  }
  list(w=w, ivp=ivp, fvp=fvp, irt=r, irt_mlk=irt_mlk)
}

# Evaluate a fitted mixture (output of em_mix_irt) on responses y.
#
# Each row's mixture log-likelihood is
#   ml = log(pi_irt * L_irt + pi_ind * L_ind + pi_flip * L_flip),
# where pi are the fit's mean responsibilities. L_irt is estimated from
# xsamp latent-trait draws (skipped when pi_irt == 0). Per-row perplexity
# is exp(-ml / n_obs), with n_obs the number of non-NA responses in the row.
# Returns list(loglik = sum of ml, perp = mean perplexity,
#   avg_n = mean n_obs).
perplexity <- function(y, fit, xsamp=1000) {
  w <- colMeans(fit$w)
  if (w[1] > 0) {
    x_idx <- sample.int(NROW(fit$irt$x), size=xsamp, replace=TRUE, prob=fit$w[,1])
    x_sample <- fit$irt$x[x_idx, , drop=FALSE]
    irt_mlk <- irt_marg_lik(y=y, x=x_sample, a=fit$irt$a, b=fit$irt$b, ptrim=0.001)
  } else {
    irt_mlk <- 0  # placeholder; log(w[1]) is -Inf, so it does not contribute
  }
  flk <- flip_vote_loglik(y)$lk
  ivp <- fit$ivp$ivp
  ilk <- apply(y, 1,
               function(yy)
                 sum(vlog(yy, ivp) + vlog(1-yy, 1-ivp), na.rm=TRUE))
  # Log-sum-exp over components: exp() of long-row log-likelihoods would
  # underflow to 0 and turn ml into -Inf.
  l_irt <- log(w[1]) + irt_mlk
  l_ind <- log(w[2]) + ilk
  l_flip <- log(w[3]) + flk
  top <- pmax(l_irt, l_ind, l_flip)
  ml <- ifelse(is.finite(top),
               top + log(exp(l_irt - top) + exp(l_ind - top) + exp(l_flip - top)),
               top)
  n <- rowSums(!is.na(y))
  perp <- exp(-ml/n)
  list(loglik=sum(ml), perp=mean(perp), avg_n = mean(n))
}

# K-fold cross-validated perplexity of a mixture fit, alongside the
# in-sample value.
#
# Each fold refits em_mix_irt on the remaining rows, warm-started from
# fit$w, then scores the held-out rows with perplexity(). The fold results
# are averaged.
# equal_marginals is forwarded to the refit. It should match the setting
# used to produce `fit`.
# Returns a 2-row tibble (estimate = "CV", "In sample") with avg_n, perp
# and loglik.
# NOTE(review): map_df, %>%, as_tibble, summarize_all, bind_rows, mutate
# and select come from purrr/magrittr/tibble/dplyr, which this file does not
# attach. The caller presumably loads them (e.g. via the tidyverse).
cv_perplexity <- function(y, fit, folds=5, iters=40, irt_iters=50, xsamp=1000,
                          equal_marginals=FALSE) {
  n <- NROW(y)
  # An empty fold would make y[-integer(0), ] select no rows at all.
  if (folds < 2 || n < folds) {
    stop("need 2 <= folds <= number of rows", call. = FALSE)
  }
  insample <- perplexity(y, fit, xsamp=xsamp)
  ndim <- NCOL(fit$irt$x)
  fold_idx <- split(seq_len(n), cut(sample(seq_len(n)), folds))
  names(fold_idx) <- sprintf("Fold %i", seq_len(folds))
  cv <- map_df(fold_idx, function(idx) {
    # drop = FALSE keeps single-row subsets as matrices.
    res <- em_mix_irt(y[-idx, , drop=FALSE], ndim=ndim, iters=iters,
                      irt_iters=irt_iters, xsamp=xsamp,
                      w_starts=fit$w[-idx, , drop=FALSE],
                      equal_marginals=equal_marginals)
    perplexity(y[idx, , drop=FALSE], res, xsamp=xsamp) %>%
      as_tibble()
  })
  cv %>%
    summarize_all(mean) %>%
    bind_rows(insample) %>%
    mutate(estimate = c("CV", "In sample")) %>%
    select(estimate, avg_n, perp, loglik)
}


# Try it out...
#dat <- mix_data(items = 30, ndim=1, sim_ind_n = 25000, sim_flip_n = 5000, sim_irt_n = 70000) 
#res1 <- em_mix_irt(dat$dat, xsamp=1000, ndim=1, w_alpha=c(10,10,1), iters=50, equal_marginals=TRUE)
#res2 <- em_mix_irt(dat$dat, xsamp=1000, ndim=1, w_alpha=c(10,10,1), iters=50)
